import matplotlib.pyplot as plt
import numpy as np # type: ignore
import pandas as pd # type: ignore


def _first_row_mag_and_phase(raw_df, axis_name, freq):
    """Return (scaled magnitude, phase in cycles) from the first row of raw_df.

    The columns looked up are 'solid.uAmp<axis> (mm) @ freq=<f>' and
    'solid.uPhase<axis> (rad) @ freq=<f>', where <f> is the integer form of
    ``freq``.
    """
    freq_label = str(int(freq))
    amplitude = raw_df.loc[:, f'solid.uAmp{axis_name} (mm) @ freq={freq_label}'].iloc[0]
    phase = raw_df.loc[:, f'solid.uPhase{axis_name} (rad) @ freq={freq_label}'].iloc[0]
    # 1e6 converts the (mm) column to nm; the extra /10 is kept from the
    # original script -- NOTE(review): presumably a stimulus normalisation, confirm.
    return amplitude * 1e6 / 10, phase / 2 / np.pi


def plot_diff_cases(cases, plot_labels, save_name):
    """Plot box versus slice displacement magnitudes and save the figure.

    Reads the 'bm' and 'ooc' COMSOL CSV exports from ``comsol_raw/`` (8 header
    rows skipped, first three columns used as the index) and draws a 2x3 grid
    of axes: the top row shows 'ooc', the bottom row 'bm', and the columns are
    the Z, X and Y components. Each axis shows the box result as a solid green
    line labelled 'full-length' plus one dashed line per slice case. Only the
    first value of each column and the seven lowest frequencies are used.

    Phase values are converted to cycles and stored alongside the magnitudes,
    but only the magnitudes are plotted.

    Args:
        cases: list of slice-case suffixes. For each one the files
            'comsol_raw/on_bm_slice<case>.csv' and
            'comsol_raw/on_ooc_slice<case>.csv' are read, and the suffix must
            be a key of the colour table below.
        plot_labels: legend labels, paired with ``cases`` by position.
        save_name: output path handed to ``plt.savefig``.

    Raises:
        KeyError: if a case has no entry in the colour table or an expected
            column is missing from a CSV file.
    """
    color_dict = {'_st': 'black',
                  '_sv': 'tab:orange',
                  '': 'tab:purple',
                  '_k_x_2': 'peachpuff',
                  '_k_x_3': 'lightsalmon',
                  '_k_div_3': 'lightblue',
                  '_cont': 'steelblue'}
    # Phase is stored in cycles (radians / 2*pi), so the columns are named for
    # that unit instead of the unused '*_phase_rad' placeholders.
    value_columns = ['bm_mag_nm', 'ooc_mag_nm', 'bm_phase_cycle', 'ooc_phase_cycle']

    raw_bm_box_df = pd.read_csv('comsol_raw/on_bm_box.csv', skiprows=8, index_col=[0, 1, 2])
    raw_ooc_box_df = pd.read_csv('comsol_raw/on_ooc_box.csv', skiprows=8, index_col=[0, 1, 2])

    # The slice files do not depend on the displacement direction, so each one
    # is read once up front rather than once per direction.
    raw_slice_dfs = {
        case: (
            pd.read_csv('comsol_raw/on_bm_slice' + case + '.csv', skiprows=8, index_col=[0, 1, 2]),
            pd.read_csv('comsol_raw/on_ooc_slice' + case + '.csv', skiprows=8, index_col=[0, 1, 2]),
        )
        for case in cases
    }

    # Frequencies are parsed from the 'freq=' suffix of the box column names;
    # only the seven lowest are kept.
    freq_strs = list(raw_bm_box_df.columns.get_level_values(0))
    freqs = sorted({float(name.split('freq=')[1]) for name in freq_strs})[:7]
    # NOTE(review): presumably Hz -> kHz, matching the 4..7 x ticks; confirm.
    plot_freqs = np.asarray(freqs) / 1000

    new_box_df = pd.DataFrame(index=freqs, columns=value_columns)
    new_box_df.index.name = "freq"

    new_slice_df = pd.DataFrame(
        index=pd.MultiIndex.from_product([cases, freqs], names=['stimul_loc', 'freq']),
        columns=value_columns,
    )
    fig, axarr = plt.subplots(2, 3, sharex='col', sharey='row', figsize=[8, 5])

    for col_idx, axis_name in enumerate(['Z', 'X', 'Y']):
        # Refill both tables for this displacement component; every row is
        # overwritten on each pass.
        for freq in freqs:
            bm_mag, bm_phase = _first_row_mag_and_phase(raw_bm_box_df, axis_name, freq)
            ooc_mag, ooc_phase = _first_row_mag_and_phase(raw_ooc_box_df, axis_name, freq)
            new_box_df.loc[freq, 'bm_mag_nm'] = bm_mag
            new_box_df.loc[freq, 'ooc_mag_nm'] = ooc_mag
            new_box_df.loc[freq, 'bm_phase_cycle'] = bm_phase
            new_box_df.loc[freq, 'ooc_phase_cycle'] = ooc_phase

        for case in cases:
            raw_bm_slice_df, raw_ooc_slice_df = raw_slice_dfs[case]
            for freq in freqs:
                bm_mag, bm_phase = _first_row_mag_and_phase(raw_bm_slice_df, axis_name, freq)
                ooc_mag, ooc_phase = _first_row_mag_and_phase(raw_ooc_slice_df, axis_name, freq)
                new_slice_df.loc[(case, freq), 'bm_mag_nm'] = bm_mag
                new_slice_df.loc[(case, freq), 'ooc_mag_nm'] = ooc_mag
                new_slice_df.loc[(case, freq), 'bm_phase_cycle'] = bm_phase
                new_slice_df.loc[(case, freq), 'ooc_phase_cycle'] = ooc_phase

        # Top row: 'ooc' with circle markers; bottom row: 'bm' with squares.
        for row_idx, (structure, marker) in enumerate(zip(['ooc', 'bm'], ['o', 's'])):
            ax = axarr[row_idx, col_idx]
            mag_column = structure + '_mag_nm'

            ax.plot(plot_freqs, new_box_df.loc[:, mag_column],
                    marker=marker, label='full-length', color='tab:green')
            for case, plot_label in zip(cases, plot_labels):
                ax.plot(plot_freqs, new_slice_df.loc[(case, slice(None)), mag_column],
                        marker=marker, color=color_dict[case], linestyle='--', label=plot_label)

            ax.set_ylim(-0.5, 20)
            ax.set_xticks([4, 5, 6, 7])
            ax.set_yticks([0, 5, 10, 15, 20])

            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)

    fig.tight_layout()
    # The legend goes on the bottom-middle axis when both '_cont' and '_k_x_3'
    # are plotted, otherwise on the top-middle axis.
    if '_cont' in cases and '_k_x_3' in cases:
        axarr[1, 1].legend()
    else:
        axarr[0, 1].legend()
    plt.savefig(save_name, transparent=True)
    plt.show()

def plot_study(study):
    """Plot and save the box-versus-slice comparison figure for one named study.

    Args:
        study: study key, one of '_locs', '_ks', '_ks_large' or '_ks_small'.
            It selects the slice cases and their legend labels, and is also
            embedded in the output file name
            'manufigs/mag_phase_box_vs_slice<study>.pdf'.

    Raises:
        ValueError: if ``study`` is not one of the known keys. Previously an
            unknown key fell through every branch and failed later with an
            unbound-variable error.
    """
    # Each entry maps a study key to (slice-case suffixes, legend labels).
    study_specs = {
        '_locs': (
            ['_st', '_sv', ''],
            ['slice: ST stml', 'slice: SV stml', 'slice: SV & ST stml'],
        ),
        '_ks': (
            ['_st', '_k_x_2', '_k_x_3', '_k_div_3', '_cont'],
            ['slice: k', 'slice: 2*k', 'slice: 3*k', 'slice: k/3', 'slice: continuity'],
        ),
        '_ks_large': (
            ['_st', '_k_x_2', '_k_x_3'],
            ['slice: k', 'slice: 2*k', 'slice: 3*k'],
        ),
        '_ks_small': (
            ['_st', '_k_div_3', '_cont'],
            ['slice: k', 'slice: k/3', 'slice: continuity'],
        ),
    }

    try:
        cases, plot_labels = study_specs[study]
    except KeyError:
        raise ValueError(
            'unknown study {!r}; expected one of {}'.format(study, sorted(study_specs))
        ) from None

    save_name = 'manufigs/mag_phase_box_vs_slice' + study + '.pdf'
    plot_diff_cases(cases, plot_labels, save_name)

# Produce one figure per study, in this order, whenever the module is run.
for _study_key in ('_locs', '_ks_large', '_ks_small'):
    plot_study(_study_key)

